package com.nx.platform.es.biz.esspider.expression;

import com.google.common.primitives.Longs;
import com.nx.platform.es.common.utils.MoreMaps;
import org.antlr.v4.runtime.tree.TerminalNode;

import java.util.Map;

/**
 * Evaluates a parsed expression tree against a map of named parameters.
 *
 * <p>The visit methods fall into two groups: those that produce a {@link Boolean}
 * (compare, exists, not, and, or, in, parenthesised expression) and those that produce
 * a {@link Long} (number literal, parameter lookup, arithmetic, bitwise operations,
 * parenthesised value). The base class is presumably generated by ANTLR from the
 * expression grammar; the grammar itself is not part of this file.
 *
 * @author
 * @date 2018/01/5
 */
public class ExpressionVisitorImpl extends ExpressionBaseVisitor<Object> {

    /** Reserved parameter name answered from the system clock rather than from {@link #params}. */
    private static final String CURRENT_TIME = "CURRENT_TIME";

    /** Parameter values that identifiers in the expression are looked up in. */
    private final Map<String, Object> params;

    /**
     * Creates a visitor bound to the given parameter map.
     *
     * @param params values used to resolve identifiers; the reference is stored as-is, not copied
     */
    public ExpressionVisitorImpl(Map<String, Object> params) {
        this.params = params;
    }

    /**
     * Casts a visit result to {@link Long} and substitutes zero for {@code null}.
     *
     * @param visited result of visiting a value node
     * @return the unboxed value, or {@code 0} when {@code visited} is {@code null}
     */
    private long longOrZero(Object visited) {
        Long boxed = (Long) visited;
        if (boxed == null) {
            return 0L;
        }
        return boxed;
    }

    /**
     * Resolves a parameter name to a numeric value.
     *
     * @param key parameter name taken from the expression
     * @return {@link System#currentTimeMillis()} for the reserved name {@code CURRENT_TIME},
     *         otherwise whatever {@code MoreMaps.getLong} yields for {@code key}
     */
    private Long getParam(String key) {
        if (!CURRENT_TIME.equals(key)) {
            return MoreMaps.getLong(params, key);
        }
        return System.currentTimeMillis();
    }

    /**
     * Compares two values with one of the operators EQ, NE, GT, GE, LT or LE.
     *
     * @return the comparison result; {@code false} when either operand evaluates to
     *         {@code null} (the right operand is not evaluated if the left one is
     *         {@code null}) or when the operator token is none of the six above
     */
    @Override
    public Boolean visitCompare(ExpressionParser.CompareContext ctx) {
        Long lhs = (Long) visit(ctx.val(0));
        if (lhs == null) {
            return false;
        }
        Long rhs = (Long) visit(ctx.val(1));
        if (rhs == null) {
            return false;
        }

        int operator = ctx.op.getType();
        int order = lhs.compareTo(rhs);
        if (operator == ExpressionParser.EQ) {
            return order == 0;
        }
        if (operator == ExpressionParser.NE) {
            return order != 0;
        }
        if (operator == ExpressionParser.GT) {
            return order > 0;
        }
        if (operator == ExpressionParser.GE) {
            return order >= 0;
        }
        if (operator == ExpressionParser.LT) {
            return order < 0;
        }
        if (operator == ExpressionParser.LE) {
            return order <= 0;
        }
        // Unrecognised operator token: treated as "does not hold".
        return false;
    }

    /**
     * Tests whether a parameter is present.
     *
     * @return {@code true} when {@code MoreMaps.getObject} returns a non-null value for the identifier
     */
    @Override
    public Boolean visitExsit(ExpressionParser.ExsitContext ctx) {
        String name = ctx.ID().getText();
        Object value = MoreMaps.getObject(params, name);
        return value != null;
    }

    /**
     * Logical negation of the nested expression.
     *
     * @return the inverse of the nested expression's boolean result
     */
    @Override
    public Boolean visitNot(ExpressionParser.NotContext ctx) {
        boolean operand = (Boolean) visit(ctx.expr());
        return !operand;
    }

    /**
     * Evaluates a parenthesised boolean expression.
     *
     * @return the result of the nested expression, unchanged
     */
    @Override
    public Boolean visitParensExpr(ExpressionParser.ParensExprContext ctx) {
        Object inner = visit(ctx.expr());
        return (Boolean) inner;
    }

    /**
     * Short-circuit logical OR of two expressions.
     *
     * @return {@code true} if the first expression holds (the second is then not evaluated),
     *         otherwise the result of the second expression
     */
    @Override
    public Boolean visitOr(ExpressionParser.OrContext ctx) {
        boolean first = (Boolean) visit(ctx.expr(0));
        if (first) {
            return true;
        }
        boolean second = (Boolean) visit(ctx.expr(1));
        return second;
    }

    /**
     * Tests whether a parameter equals any of the listed number literals.
     *
     * @return {@code false} when the parameter resolves to {@code null}; otherwise
     *         {@code true} as soon as one literal, parsed with {@code Longs.tryParse},
     *         equals the parameter value
     */
    @Override
    public Boolean visitIn(ExpressionParser.InContext ctx) {
        Long target = getParam(ctx.ID().getText());
        if (target == null) {
            return false;
        }
        for (TerminalNode literal : ctx.NUM()) {
            Long candidate = Longs.tryParse(literal.getText());
            // equals(null) is false, so a literal that fails to parse never matches.
            if (target.equals(candidate)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Short-circuit logical AND of two expressions.
     *
     * @return {@code false} if the first expression does not hold (the second is then
     *         not evaluated), otherwise the result of the second expression
     */
    @Override
    public Boolean visitAnd(ExpressionParser.AndContext ctx) {
        boolean first = (Boolean) visit(ctx.expr(0));
        if (!first) {
            return false;
        }
        boolean second = (Boolean) visit(ctx.expr(1));
        return second;
    }

    /**
     * Parses a number literal.
     *
     * @return the result of {@code Longs.tryParse} on the literal's text
     */
    @Override
    public Long visitNumber(ExpressionParser.NumberContext ctx) {
        String text = ctx.NUM().getText();
        return Longs.tryParse(text);
    }

    /**
     * Bitwise AND / OR of two values; a {@code null} operand counts as zero.
     *
     * @return {@code left & right} for the BITAND operator, {@code left | right} for any other operator
     */
    @Override
    public Long visitBitAndOr(ExpressionParser.BitAndOrContext ctx) {
        long lhs = longOrZero(visit(ctx.val(0)));
        long rhs = longOrZero(visit(ctx.val(1)));
        boolean isAnd = ctx.op.getType() == ExpressionParser.BITAND;
        return isAnd ? (lhs & rhs) : (lhs | rhs);
    }

    /**
     * Addition / subtraction of two values; a {@code null} operand counts as zero.
     *
     * @return {@code left + right} for the ADD operator, {@code left - right} for any other operator
     */
    @Override
    public Long visitAddSub(ExpressionParser.AddSubContext ctx) {
        long lhs = longOrZero(visit(ctx.val(0)));
        long rhs = longOrZero(visit(ctx.val(1)));
        boolean isAdd = ctx.op.getType() == ExpressionParser.ADD;
        return isAdd ? (lhs + rhs) : (lhs - rhs);
    }

    /**
     * Resolves a parameter reference.
     *
     * @return the value {@link #getParam(String)} yields for the identifier; may be {@code null}
     */
    @Override
    public Long visitParams(ExpressionParser.ParamsContext ctx) {
        String name = ctx.ID().getText();
        return getParam(name);
    }

    /**
     * Evaluates a parenthesised value.
     *
     * @return the result of the nested value, unchanged
     */
    @Override
    public Long visitParensVal(ExpressionParser.ParensValContext ctx) {
        Object inner = visit(ctx.val());
        return (Long) inner;
    }

    /**
     * Multiplication / division of two values; a {@code null} operand counts as zero.
     *
     * @return {@code left * right} for the MUL operator; for any other operator the
     *         integer quotient {@code left / right}, or {@link Long#MAX_VALUE} when the
     *         divisor is zero
     */
    @Override
    public Long visitMuldiv(ExpressionParser.MuldivContext ctx) {
        long lhs = longOrZero(visit(ctx.val(0)));
        long rhs = longOrZero(visit(ctx.val(1)));
        if (ctx.op.getType() == ExpressionParser.MUL) {
            return lhs * rhs;
        }
        // Division by zero is mapped to the largest long instead of throwing.
        if (rhs == 0L) {
            return Long.MAX_VALUE;
        }
        return lhs / rhs;
    }

}
